In [1]:
from collections import Counter
from tqdm import tqdm
import pandas as pd
import numpy as np
import plotly.express as px

from imblearn.datasets import fetch_datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
# from tabpfn import TabPFNClassifier
In [2]:
datasets = fetch_datasets()
datasets
Out[2]:
OrderedDict([('ecoli',
              {'data': array([[0.49, 0.29, 0.48, ..., 0.56, 0.24, 0.35],
                      [0.07, 0.4 , 0.48, ..., 0.54, 0.35, 0.44],
                      [0.56, 0.4 , 0.48, ..., 0.49, 0.37, 0.46],
                      ...,
                      [0.61, 0.6 , 0.48, ..., 0.44, 0.39, 0.38],
                      [0.59, 0.61, 0.48, ..., 0.42, 0.42, 0.37],
                      [0.74, 0.74, 0.48, ..., 0.31, 0.53, 0.52]]),
               'target': array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
                       1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
                       1,  1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
               'DESCR': 'ecoli'}),
             ('optical_digits',
              {'data': array([[ 0.,  1.,  6., ...,  1.,  0.,  0.],
                      [ 0.,  0., 10., ...,  3.,  0.,  0.],
                      [ 0.,  0.,  8., ...,  0.,  0.,  0.],
                      ...,
                      [ 0.,  0.,  1., ...,  6.,  0.,  0.],
                      [ 0.,  0.,  2., ..., 12.,  0.,  0.],
                      [ 0.,  0., 10., ..., 12.,  1.,  0.]]),
               'target': array([-1, -1, -1, ...,  1, -1,  1]),
               'DESCR': 'optical_digits'}),
             ('satimage',
              {'data': array([[ 92., 115., 120., ..., 107., 113.,  87.],
                      [ 84., 102., 106., ...,  99., 104.,  79.],
                      [ 84., 102., 102., ...,  99., 104.,  79.],
                      ...,
                      [ 56.,  68.,  91., ...,  83.,  92.,  74.],
                      [ 56.,  68.,  87., ...,  83.,  92.,  70.],
                      [ 60.,  71.,  91., ...,  79., 108.,  92.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'satimage'}),
             ('pen_digits',
              {'data': array([[ 47., 100.,  27., ...,  90.,  40.,  98.],
                      [  0.,  89.,  27., ...,   2., 100.,   6.],
                      [  0.,  57.,  31., ...,  25.,  16.,   0.],
                      ...,
                      [ 56., 100.,  27., ...,  93.,  38.,  93.],
                      [ 19., 100.,   0., ...,  97.,  10.,  81.],
                      [ 38., 100.,  37., ...,  26.,  65.,   0.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'pen_digits'}),
             ('abalone',
              {'data': array([[0.    , 0.    , 1.    , ..., 0.2245, 0.101 , 0.15  ],
                      [0.    , 0.    , 1.    , ..., 0.0995, 0.0485, 0.07  ],
                      [1.    , 0.    , 0.    , ..., 0.2565, 0.1415, 0.21  ],
                      ...,
                      [0.    , 0.    , 1.    , ..., 0.5255, 0.2875, 0.308 ],
                      [1.    , 0.    , 0.    , ..., 0.531 , 0.261 , 0.296 ],
                      [0.    , 0.    , 1.    , ..., 0.9455, 0.3765, 0.495 ]]),
               'target': array([-1,  1, -1, ..., -1, -1, -1]),
               'DESCR': 'abalone'}),
             ('sick_euthyroid',
              {'data': array([[ 72.,   0.,   1., ...,  87.,   1.,   0.],
                      [ 45.,   1.,   0., ..., 112.,   1.,   0.],
                      [ 64.,   1.,   0., ..., 123.,   1.,   0.],
                      ...,
                      [ 58.,   1.,   0., ...,  95.,   1.,   0.],
                      [ 29.,   1.,   0., ...,  98.,   1.,   0.],
                      [ 56.,   1.,   0., ..., 143.,   1.,   0.]]),
               'target': array([ 1,  1,  1, ..., -1, -1, -1]),
               'DESCR': 'sick_euthyroid'}),
             ('spectrometer',
              {'data': array([[ 4119.1675 ,  4897.299  ,  4163.969  , ...,  1392.4745 ,
                        1278.9945 ,  1440.482  ],
                      [ 7660.999  ,  7906.784  ,  7821.8984 , ...,  7015.5747 ,
                        6962.22   ,  6263.44   ],
                      [ 3196.4287 ,  3013.9722 ,  3003.149  , ...,  5954.388  ,
                        5337.8887 ,  4638.5244 ],
                      ...,
                      [15375.604  , 14542.233  , 13849.163  , ...,   824.46655,
                         703.36536,   649.1576 ],
                      [11814.46   , 12896.945  , 13033.315  , ...,  1677.9924 ,
                        1601.7352 ,  1470.1552 ],
                      [14268.037  , 12925.045  , 12433.298  , ...,   702.66327,
                         681.01196,   670.62964]]),
               'target': array([-1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1,
                       1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1,
                       1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1,  1,  1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
                       1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,
                      -1, -1, -1, -1]),
               'DESCR': 'spectrometer'}),
             ('car_eval_34',
              {'data': array([[0., 0., 0., ..., 0., 1., 0.],
                      [0., 0., 0., ..., 0., 0., 1.],
                      [0., 0., 0., ..., 1., 0., 0.],
                      ...,
                      [0., 1., 0., ..., 0., 1., 0.],
                      [0., 1., 0., ..., 0., 0., 1.],
                      [0., 1., 0., ..., 1., 0., 0.]]),
               'target': array([-1, -1, -1, ..., -1,  1,  1]),
               'DESCR': 'car_eval_34'}),
             ('isolet',
              {'data': array([[-0.4394, -0.093 ,  0.1718, ...,  0.641 ,  0.5898, -0.4872],
                      [-0.4348, -0.1198,  0.2474, ...,  0.4318,  0.4546, -0.091 ],
                      [-0.233 ,  0.2124,  0.5014, ...,  0.254 ,  0.1588, -0.4762],
                      ...,
                      [-0.6696, -0.373 ,  0.1584, ...,  0.0728,  0.0728, -0.5818],
                      [-0.5764, -0.1764,  0.5106, ...,  0.3044, -0.0434, -0.5   ],
                      [-0.6624, -0.3334,  0.3666, ..., -0.0894, -0.1708, -0.317 ]]),
               'target': array([ 1,  1,  1, ..., -1, -1, -1]),
               'DESCR': 'isolet'}),
             ('us_crime',
              {'data': array([[0.19, 0.33, 0.02, ..., 0.26, 0.2 , 0.32],
                      [0.  , 0.16, 0.12, ..., 0.12, 0.45, 0.  ],
                      [0.  , 0.42, 0.49, ..., 0.21, 0.02, 0.  ],
                      ...,
                      [0.16, 0.37, 0.25, ..., 0.32, 0.18, 0.91],
                      [0.08, 0.51, 0.06, ..., 0.38, 0.33, 0.22],
                      [0.2 , 0.78, 0.14, ..., 0.3 , 0.05, 1.  ]]),
               'target': array([-1,  1, -1, ..., -1, -1, -1]),
               'DESCR': 'us_crime'}),
             ('yeast_ml8',
              {'data': array([[ 0.0937  ,  0.139771,  0.062774, ..., -0.042402,  0.118473,
                        0.125632],
                      [-0.022711, -0.050504, -0.035691, ..., -0.014191,  0.022783,
                        0.123785],
                      [-0.090407,  0.021198,  0.208712, ..., -0.063378, -0.084181,
                       -0.034402],
                      ...,
                      [ 0.2416  ,  0.127602, -0.033072, ..., -0.038713, -0.026947,
                        0.00562 ],
                      [ 0.097274,  0.088109,  0.161101, ..., -0.019985,  0.280843,
                        0.143382],
                      [-0.001043,  0.030495,  0.007199, ...,  0.006505, -0.041307,
                       -0.146233]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'yeast_ml8'}),
             ('scene',
              {'data': array([[0.646467  , 0.666435  , 0.685047  , ..., 0.247298  , 0.0140249 ,
                       0.0297093 ],
                      [0.770156  , 0.767255  , 0.761053  , ..., 0.137833  , 0.0826722 ,
                       0.0363203 ],
                      [0.793984  , 0.772096  , 0.76182   , ..., 0.0511252 , 0.112506  ,
                       0.0839236 ],
                      ...,
                      [0.952281  , 0.944987  , 0.905556  , ..., 0.0319002 , 0.0175471 ,
                       0.0197344 ],
                      [0.88399   , 0.899004  , 0.901019  , ..., 0.256158  , 0.226332  ,
                       0.22307   ],
                      [0.974915  , 0.866425  , 0.818144  , ..., 0.0051313 , 0.0250591 ,
                       0.00403332]]),
               'target': array([ 1,  1, -1, ..., -1, -1, -1]),
               'DESCR': 'scene'}),
             ('libras_move',
              {'data': array([[0.79691, 0.38194, 0.79691, ..., 0.3125 , 0.6383 , 0.29398],
                      [0.67892, 0.27315, 0.68085, ..., 0.69213, 0.17215, 0.69213],
                      [0.72147, 0.23611, 0.7234 , ..., 0.2662 , 0.78143, 0.27778],
                      ...,
                      [0.61122, 0.75926, 0.61122, ..., 0.52083, 0.44487, 0.5162 ],
                      [0.65957, 0.79167, 0.65764, ..., 0.52546, 0.54159, 0.52083],
                      [0.64023, 0.71991, 0.64217, ..., 0.49537, 0.52031, 0.49306]]),
               'target': array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
                       1,  1,  1,  1,  1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1]),
               'DESCR': 'libras_move'}),
             ('thyroid_sick',
              {'data': array([[41.,  1.,  0., ...,  1.,  0.,  0.],
                      [23.,  1.,  0., ...,  0.,  0.,  0.],
                      [46.,  0.,  1., ...,  0.,  0.,  0.],
                      ...,
                      [74.,  1.,  0., ...,  0.,  0.,  0.],
                      [72.,  0.,  1., ...,  0.,  0.,  1.],
                      [64.,  1.,  0., ...,  0.,  0.,  0.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'thyroid_sick'}),
             ('coil_2000',
              {'data': array([[33.,  1.,  3., ...,  0.,  0.,  0.],
                      [37.,  1.,  2., ...,  0.,  0.,  0.],
                      [37.,  1.,  2., ...,  0.,  0.,  0.],
                      ...,
                      [36.,  1.,  2., ...,  0.,  1.,  0.],
                      [33.,  1.,  3., ...,  0.,  0.,  0.],
                      [ 8.,  1.,  2., ...,  0.,  0.,  0.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'coil_2000'}),
             ('arrhythmia',
              {'data': array([[ 75. ,   0. , 190. , ...,   2.9,  23.3,  49.4],
                      [ 56. ,   1. , 165. , ...,   2.1,  20.4,  38.8],
                      [ 54. ,   0. , 172. , ...,   3.4,  12.3,  49. ],
                      ...,
                      [ 36. ,   0. , 166. , ...,   1. , -44.2, -33.2],
                      [ 32. ,   1. , 155. , ...,   2.4,  25. ,  46.6],
                      [ 78. ,   1. , 160. , ...,   1.6,  21.3,  32.8]]),
               'target': array([-1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,
                      -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                       1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                       1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                       1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
               'DESCR': 'arrhythmia'}),
             ('solar_flare_m0',
              {'data': array([[0., 1., 0., ..., 1., 1., 0.],
                      [0., 0., 1., ..., 1., 1., 0.],
                      [0., 1., 0., ..., 1., 0., 1.],
                      ...,
                      [0., 1., 0., ..., 1., 0., 1.],
                      [0., 0., 0., ..., 1., 0., 1.],
                      [1., 0., 0., ..., 1., 0., 1.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'solar_flare_m0'}),
             ('oil',
              {'data': array([[1.000000e+00, 2.558000e+03, 1.506090e+03, ..., 3.324319e+04,
                       6.574000e+01, 7.950000e+00],
                      [2.000000e+00, 2.232500e+04, 7.911000e+01, ..., 5.157204e+04,
                       6.573000e+01, 6.260000e+00],
                      [3.000000e+00, 1.150000e+02, 1.449850e+03, ..., 3.169284e+04,
                       6.581000e+01, 7.840000e+00],
                      ...,
                      [2.020000e+02, 1.400000e+01, 2.514000e+01, ..., 2.153050e+03,
                       6.591000e+01, 6.120000e+00],
                      [2.030000e+02, 1.000000e+01, 9.600000e+01, ..., 2.421430e+03,
                       6.597000e+01, 6.320000e+00],
                      [2.040000e+02, 1.100000e+01, 7.730000e+00, ..., 3.782680e+03,
                       6.565000e+01, 6.260000e+00]]),
               'target': array([ 1, -1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,
                      -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,
                       1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1,
                      -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1,  1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                      -1, -1]),
               'DESCR': 'oil'}),
             ('car_eval_4',
              {'data': array([[0., 0., 0., ..., 0., 1., 0.],
                      [0., 0., 0., ..., 0., 0., 1.],
                      [0., 0., 0., ..., 1., 0., 0.],
                      ...,
                      [0., 1., 0., ..., 0., 1., 0.],
                      [0., 1., 0., ..., 0., 0., 1.],
                      [0., 1., 0., ..., 1., 0., 0.]]),
               'target': array([-1, -1, -1, ..., -1, -1,  1]),
               'DESCR': 'car_eval_4'}),
             ('wine_quality',
              {'data': array([[ 7.  ,  0.27,  0.36, ...,  3.  ,  0.45,  8.8 ],
                      [ 6.3 ,  0.3 ,  0.34, ...,  3.3 ,  0.49,  9.5 ],
                      [ 8.1 ,  0.28,  0.4 , ...,  3.26,  0.44, 10.1 ],
                      ...,
                      [ 6.5 ,  0.24,  0.19, ...,  2.99,  0.46,  9.4 ],
                      [ 5.5 ,  0.29,  0.3 , ...,  3.34,  0.38, 12.8 ],
                      [ 6.  ,  0.21,  0.38, ...,  3.26,  0.32, 11.8 ]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'wine_quality'}),
             ('letter_img',
              {'data': array([[ 2.,  8.,  3., ...,  8.,  0.,  8.],
                      [ 5., 12.,  3., ...,  8.,  4., 10.],
                      [ 4., 11.,  6., ...,  7.,  3.,  9.],
                      ...,
                      [ 6.,  9.,  6., ..., 12.,  2.,  4.],
                      [ 2.,  3.,  4., ...,  9.,  5.,  8.],
                      [ 4.,  9.,  6., ...,  7.,  2.,  8.]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'letter_img'}),
             ('yeast_me2',
              {'data': array([[0.58, 0.61, 0.47, ..., 0.  , 0.48, 0.22],
                      [0.43, 0.67, 0.48, ..., 0.  , 0.53, 0.22],
                      [0.64, 0.62, 0.49, ..., 0.  , 0.53, 0.22],
                      ...,
                      [0.67, 0.57, 0.36, ..., 0.  , 0.56, 0.22],
                      [0.43, 0.4 , 0.6 , ..., 0.  , 0.53, 0.39],
                      [0.65, 0.54, 0.54, ..., 0.  , 0.53, 0.22]]),
               'target': array([-1, -1, -1, ...,  1, -1, -1]),
               'DESCR': 'yeast_me2'}),
             ('webpage',
              {'data': array([[0., 0., 0., ..., 0., 0., 0.],
                      [0., 0., 0., ..., 0., 1., 0.],
                      [0., 0., 0., ..., 1., 0., 0.],
                      ...,
                      [0., 0., 0., ..., 1., 1., 0.],
                      [0., 0., 0., ..., 0., 0., 0.],
                      [0., 0., 0., ..., 0., 1., 0.]]),
               'target': array([-1, -1, -1, ...,  1,  1,  1]),
               'DESCR': 'webpage'}),
             ('ozone_level',
              {'data': array([[ 8.0000e-01,  1.8000e+00,  2.4000e+00, ...,  1.0330e+04,
                       -5.5000e+01,  0.0000e+00],
                      [ 2.8000e+00,  3.2000e+00,  3.3000e+00, ...,  1.0275e+04,
                       -5.5000e+01,  0.0000e+00],
                      [ 2.9000e+00,  2.8000e+00,  2.6000e+00, ...,  1.0235e+04,
                       -4.0000e+01,  0.0000e+00],
                      ...,
                      [ 8.0000e-01,  8.0000e-01,  1.2000e+00, ...,  1.0275e+04,
                       -3.5000e+01,  0.0000e+00],
                      [ 1.3000e+00,  9.0000e-01,  1.5000e+00, ...,  1.0245e+04,
                       -3.0000e+01,  5.0000e-02],
                      [ 1.5000e+00,  1.3000e+00,  1.8000e+00, ...,  1.0220e+04,
                       -2.5000e+01,  0.0000e+00]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'ozone_level'}),
             ('mammography',
              {'data': array([[ 0.23001961,  5.0725783 , -0.27606055,  0.83244412, -0.37786573,
                        0.4803223 ],
                      [ 0.15549112, -0.16939038,  0.67065219, -0.85955255, -0.37786573,
                       -0.94572324],
                      [-0.78441482, -0.44365372,  5.6747053 , -0.85955255, -0.37786573,
                       -0.94572324],
                      ...,
                      [ 1.2049878 ,  1.7637238 , -0.50146835,  1.5624078 ,  6.4890725 ,
                        0.93129397],
                      [ 0.73664398, -0.22247361, -0.05065276,  1.5096647 ,  0.53926914,
                        1.3152293 ],
                      [ 0.17700275, -0.19150839, -0.50146835,  1.5788636 ,  7.750705  ,
                        1.5559507 ]]),
               'target': array([-1, -1, -1, ...,  1,  1,  1]),
               'DESCR': 'mammography'}),
             ('protein_homo',
              {'data': array([[ 52.  ,  32.69,   0.3 , ...,  -0.35,   0.26,   0.76],
                      [ 58.  ,  33.33,   0.  , ...,   1.16,   0.39,   0.73],
                      [ 77.  ,  27.27,  -0.91, ...,  -0.76,   0.26,   0.24],
                      ...,
                      [100.  ,  71.76,  41.92, ...,   3.41,   0.44,   0.78],
                      [ 85.65,  26.46,   1.85, ...,   2.88,   0.54,   0.77],
                      [ 87.5 ,  29.33,   5.84, ...,  -0.58,   0.16,   0.23]]),
               'target': array([-1, -1, -1, ...,  1, -1,  1]),
               'DESCR': 'protein_homo'}),
             ('abalone_19',
              {'data': array([[0.    , 0.    , 1.    , ..., 0.2245, 0.101 , 0.15  ],
                      [0.    , 0.    , 1.    , ..., 0.0995, 0.0485, 0.07  ],
                      [1.    , 0.    , 0.    , ..., 0.2565, 0.1415, 0.21  ],
                      ...,
                      [0.    , 0.    , 1.    , ..., 0.5255, 0.2875, 0.308 ],
                      [1.    , 0.    , 0.    , ..., 0.531 , 0.261 , 0.296 ],
                      [0.    , 0.    , 1.    , ..., 0.9455, 0.3765, 0.495 ]]),
               'target': array([-1, -1, -1, ..., -1, -1, -1]),
               'DESCR': 'abalone_19'})])
In [3]:
wine_quality = datasets["wine_quality"]
data, target = wine_quality["data"], wine_quality["target"]
data, target
Out[3]:
(array([[ 7.  ,  0.27,  0.36, ...,  3.  ,  0.45,  8.8 ],
        [ 6.3 ,  0.3 ,  0.34, ...,  3.3 ,  0.49,  9.5 ],
        [ 8.1 ,  0.28,  0.4 , ...,  3.26,  0.44, 10.1 ],
        ...,
        [ 6.5 ,  0.24,  0.19, ...,  2.99,  0.46,  9.4 ],
        [ 5.5 ,  0.29,  0.3 , ...,  3.34,  0.38, 12.8 ],
        [ 6.  ,  0.21,  0.38, ...,  3.26,  0.32, 11.8 ]]),
 array([-1, -1, -1, ..., -1, -1, -1]))
In [4]:
target = (target == 1).astype(int)
target
Out[4]:
array([0, 0, 0, ..., 0, 0, 0])
In [5]:
data.shape, target.shape
Out[5]:
((4898, 11), (4898,))
In [6]:
Counter(target)
Out[6]:
Counter({0: 4715, 1: 183})
In [7]:
4715 / 183
Out[7]:
25.76502732240437
In [8]:
columns = [
    "fixed_acidity",
    "volatile_acidity",
    "citric_acid",
    "residual_sugar",
    "chlorides",
    "free_sulfur_dioxide",
    "total_sulfur_dioxide",
    "density",
    "pH",
    "sulphates",
    "alcohol",
]
In [9]:
df = pd.DataFrame(data, columns=columns)
df.head()
Out[9]:
fixed_acidity volatile_acidity citric_acid residual_sugar chlorides free_sulfur_dioxide total_sulfur_dioxide density pH sulphates alcohol
0 7.0 0.27 0.36 20.7 0.045 45.0 170.0 1.0010 3.00 0.45 8.8
1 6.3 0.30 0.34 1.6 0.049 14.0 132.0 0.9940 3.30 0.49 9.5
2 8.1 0.28 0.40 6.9 0.050 30.0 97.0 0.9951 3.26 0.44 10.1
3 7.2 0.23 0.32 8.5 0.058 47.0 186.0 0.9956 3.19 0.40 9.9
4 7.2 0.23 0.32 8.5 0.058 47.0 186.0 0.9956 3.19 0.40 9.9
In [10]:
X_train, X_test, y_train, y_test = train_test_split(df, target, test_size=0.3, random_state=42, shuffle=True)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
Out[10]:
((3428, 11), (1470, 11), (3428,), (1470,))
In [11]:
params = {
    "random_state": 42,
    "n_jobs": -1,
}

models = [
    LogisticRegression(
        max_iter=10_000,
        **params
    ),
    RandomForestClassifier(
        n_estimators=50,
        **params
    ),
    XGBClassifier(
        **params
    ),
    # TabPFNClassifier(),
]

model_names = [
    "LogisticRegression",
    "RandomForestClassifier",
    "XGBClassifier",
]
In [12]:
predictions = []

for model in tqdm(models):
    model.fit(X_train.values, y_train)
    prediction = model.predict(X_test.values)
    predictions.append(prediction)

predictions
100%|██████████| 3/3 [00:00<00:00,  3.71it/s]
Out[12]:
[array([0, 0, 0, ..., 0, 0, 0]),
 array([0, 0, 0, ..., 0, 0, 0]),
 array([0, 0, 0, ..., 0, 0, 0])]
In [13]:
metrics = [
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
]
In [14]:
summary = {}

for prediction, model_name in zip(predictions, model_names):
    print(model_name)
    print(Counter(prediction))
    model_summary = {}
    for metric in metrics:
        print(metric.__name__)
        score = metric(y_test, prediction)
        model_summary[metric.__name__] = score
        print(score)
    summary[f"{model_name} {Counter(prediction)}"] = model_summary
    print()

summary
LogisticRegression
Counter({0: 1470})
accuracy_score
0.9680272108843537
precision_score
0.0
recall_score
0.0
f1_score
0.0

RandomForestClassifier
Counter({0: 1460, 1: 10})
accuracy_score
0.9707482993197278
precision_score
0.7
recall_score
0.14893617021276595
f1_score
0.24561403508771928

XGBClassifier
Counter({0: 1452, 1: 18})
accuracy_score
0.9693877551020408
precision_score
0.5555555555555556
recall_score
0.2127659574468085
f1_score
0.3076923076923077
/home/karol/anaconda3/envs/xml/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
Out[14]:
{'LogisticRegression Counter({0: 1470})': {'accuracy_score': 0.9680272108843537,
  'precision_score': 0.0,
  'recall_score': 0.0,
  'f1_score': 0.0},
 'RandomForestClassifier Counter({0: 1460, 1: 10})': {'accuracy_score': 0.9707482993197278,
  'precision_score': 0.7,
  'recall_score': 0.14893617021276595,
  'f1_score': 0.24561403508771928},
 'XGBClassifier Counter({0: 1452, 1: 18})': {'accuracy_score': 0.9693877551020408,
  'precision_score': 0.5555555555555556,
  'recall_score': 0.2127659574468085,
  'f1_score': 0.3076923076923077}}
In [15]:
df_summary = pd.DataFrame(summary).T
df_summary
Out[15]:
accuracy_score precision_score recall_score f1_score
LogisticRegression Counter({0: 1470}) 0.968027 0.000000 0.000000 0.000000
RandomForestClassifier Counter({0: 1460, 1: 10}) 0.970748 0.700000 0.148936 0.245614
XGBClassifier Counter({0: 1452, 1: 18}) 0.969388 0.555556 0.212766 0.307692
In [16]:
fig = px.bar(df_summary, barmode="group", title="Comparison of metrics on wine_quality dataset")
# fig.write_html("metrics.html")
fig.show()

Lime

In [17]:
import lime
from lime import lime_tabular
In [18]:
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=X_train.values,  
    feature_names=X_train.columns,
    mode="classification"
)
In [19]:
def check_lime(idxs):
    if isinstance(idxs, int):
        idxs = [idxs]

    for idx in idxs:
        for model in models:
            lime_explanation = lime_explainer.explain_instance(
                data_row=X_test.iloc[idx],
                predict_fn=lambda d: model.predict_proba(d)
            )
            
            print(f"Describing {model.__class__} for {idx} of label {y_test[idx]}")
            lime_explanation.show_in_notebook()

Check for label 0

In [20]:
check_lime(np.where(y_test == 0)[0][:3])
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 0 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 0 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 0 of label 0
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 1 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 1 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 1 of label 0
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 2 of label 0
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 2 of label 0
Describing <class 'xgboost.sklearn.XGBClassifier'> for 2 of label 0

By using lime we can see what features models consider as indicators of samples' label.

Check for label 1

In [21]:
check_lime(np.where(y_test == 1)[0][:3])
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 44 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 44 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 44 of label 1
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 72 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 72 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 72 of label 1

We can see that some features are highly associated with either label. For example free_sulfur_dioxide for label 1 or chlorides with label 0. Changing this values can influence models' decisions.

Manipulate to get label 1

Get index of label 0

In [23]:
saved_idx = int(np.where(y_test == 1)[0][0])
check_lime(saved_idx)
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1

Models are sure that this data point is label 0

In [24]:
saved = X_test.iloc[saved_idx]
X_test.iloc[saved_idx][["free_sulfur_dioxide", "density", "alcohol"]] = 0
In [25]:
check_lime(saved_idx)
X_test.iloc[saved_idx] = saved
Describing <class 'sklearn.linear_model._logistic.LogisticRegression'> for 7 of label 1
Describing <class 'sklearn.ensemble._forest.RandomForestClassifier'> for 7 of label 1
Describing <class 'xgboost.sklearn.XGBClassifier'> for 7 of label 1

We were able to fool LogisticRegression to think it is label 1 with 98% probability.